{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Expaction-Maximization\n",
    "\n",
    "通常在一些机器学习应用里，我们都是给定一些数据，利用极大似然估计出模型的参数。但在某些情况下我们有一部分数据是未知的，模型也是未知的，这个时候可以通过 EM 算法来估计出隐含数据和模型参数、\n",
    "\n",
    "理解 EM 算法需要理解两个过程，一个过程是已知所有数据，进而估计模型参数；另一个过程是已知模型和一部分数据，计算出隐含数据。\n",
    "\n",
    "EM 的做法是启发式的、迭代式的：\n",
    "\n",
    "1. 随机假定一个模型，然后使用这个模型计算出隐含数据 （E step）\n",
    "2. 基于隐含数据，利用极大似然估计出新的模型参数（M step）\n",
    "\n",
    "这两步骤不断迭代，直至收敛。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data preparation\n",
    "cluster1 = np.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], 50)\n",
    "cluster2 = np.random.multivariate_normal([3, 4], [[1, 0], [0, 1]], 50)\n",
    "cluster3 = np.random.multivariate_normal([7, 7], [[1, 0], [0, 1]], 50)\n",
    "data = np.vstack([cluster1, cluster2, cluster3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot(data):\n",
    "    data = np.array(data)\n",
    "    plt.scatter(data[:, 0], data[:, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWkAAAD2CAYAAAAUPHZsAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAedUlEQVR4nO3df2yd1XkH8O9j+ya5CWCHEinBIQVElXWAyw9XRYQgNVnNOjeQsii0FI1q0yJVq3BpFRakEtxqUkL5g5m20sRQNbS1g5Qh09aq4g4qpaQwmpDGdN0i2lBKTOiCwE5JnPhHzv64vva91+95f9z3xznve74fqYp97837nluT5x4/5znPEaUUiIjITi2mB0BERHoM0kREFmOQJiKyGIM0EZHFGKSJiCzWluTFLrroInXppZcmeUkiosI7ePDgO0qpFV7PJRqkL730Uhw4cCDJSxIRFZ6IvKF7jukOIiKLMUgTEVmMQZqIyGIM0kREFmOQJiKyGIM0EZHFGKSJiCzGIE1EZDEGaSIfQ0eH0PN0D7qe6ELP0z0YOjpkekjOcf1nkOiOQ6IiGTo6hP6f9+PMzBkAwPFTx9H/834AQO/lvQZHlr6ho0MYeGUAb596GyuXrUTfdX1G3rPLP4MqzqSJNAZeGZgLDlVnZs5g4JUBQyPKRjUwHj91HApqLjCamMG6+jOoxSBNpPH2qbcjPV4UNgVGV38GtRikiTRWLlsZ6fGisCkwuvozqMUgTaTRd10flrQuqXtsSesS9F3XZ2hE2bApMLr6M6jFIE2k0Xt5L/pv7MeqZasgEKxatgr9N/YXfsHKpsDo6s+gliilErtYd3e3Yj9pogpbKiSakeex55GIHFRKdXs9xxI8ohTkvXSs9/LeXIzTBUx3EKXApgoJyjcGaaIU2FQhQfkWGKRFZJmIPCsi+0XkG1kMiijvbKqQoHwLM5P+HICXlFLrAFwpIh9OeUxEuWdThQTlW5iFw7MAloqIAFgCYDLdIRHlX3XRjRUSFFdgCZ6IlAC8COA8AM8ppf6u4fltALYBwJo1a65/4w3tyeREZCGW25nnV4IXJkjvBPCWUupxEfl3AN9USv3c67WskybKl8ZSQQAotZSwtG0pTk6eZNDOSNw66fMBVH+CZ1GZURNRAXiVCk6dm8L45DiA/NV3F1GYhcNvA/iCiLwIoAzguXSHRERZCVMSyPpuswJn0kqp3wFYl/5QiChrK5etxPFTxwNfl2Z9N3Pi/riZhSjH4h4t5VUq6CWt+m6bDhiwFYM0UU4lEeAau8x1LO5Am9T/gp1mfTe3zwdjgyWimEz9uu4X4KLcv7GZUpbvh9vngzFIE8WQdbe7oaND2P3yboydHdO+Jm6Ay7IDni4nzu3z85juIIohy1/Xh44O4YH9D/gGaCBfAY7b54NxJk0UQ5a/rg+8MoCpc1O+rwkT4GyqpuD2+WAM0kQxZPnrelDgb5GWulm8V6Cz8TACHjDgj+kOohiy/HU9KPCfU+cAwLfKg9UU+cMgTRRDlgel9l3Xh1JLKdRrdYFXNxs/fup4rHprSg/THUQxNfPrejN54erztdUd7Yva5/psNPIKyH47DKuP25ACoXk8LZwoY16d55a0LkH/jf0Aoi+i9Tzd4xl4Vy1bhb7r+uqud/Pqm/Hsb55dkPLwsmrZKgxvGY725qgpsVqVRsEgTRRMF1Q7FnfgzPQZz+DtF6h1Qf+2K25bEJCrj+87tm8ucOtm1gLByN0jzbxFisgvSDMnTWaN7AEeuQro76j8ObLH9IiaEqWHhi4vPHZ2rKlFPV1efN+xfZ7X23dsH4a3DGPk7hEMbxnGqmWrPK+bp3rrImNOmswZ2QP88B5gaqLy/fible8BoGuruXFFFLWsLWznuaowNddeefEdP9vh+drGe/dd1+c5E6+tULGpttrG8aSJM2ky57mvzwfoqqmJyuM5ErWsTVe2176o3fP1zc5oW8T7n3fj40EVKrZ1qrNtPGnjTJrMGT8W7XFLRd11qNtlByBwRhtFtW7a63GvmahukTCpRk5JsW08aWOQJnPaV1dSHF6P50gzuw5r0xO1AbN9cTsWty5O5HzBVctWaRcoo6RnbOtUZ9t40sZ0B5mzcSdQKtc/VipXHs+ROLsOG391Hzs7hrMzZ7Fr/S4MbxmONTP0GleppaRdoNz1X7s8r6P7sDG1sGjbeNLGIE3mdG0FNj0KtF8CQCp/bno0V4uGQLxdh2lu0/Zq6O9Xcjs+Oe6Z17WtU50t44l7Kk5YrJMmMqjriS4oLPw3mEaNsq4+u5ZuA0tW1RRh72O6usNvQ1Iz4/Crk2ZOmsggm7roVV+jC4BpB8EopYymO+dluXjJdAeRQTZ10QOA9sXtxsrb8tShL8vFSwZpMqsgOw6b1Uw+u9lcaNDJ4Etal0ApZSxQ5qlqI8vFSwZpMqe643D8TQBqfsehg4G6dps2AG0Q9trI8dUXvor1T64PDNqNHwjti9rRsbij7sPh5ORJz78bJ1CG/VDJU9VGlr8BceGQzHnkKk2d9CXAvb/KfjwWCFqQCrP4F2cBy6+jXjMd8aIssCXdHTBtSS5esgse2am/A/CobAAE6Pc/bLWogoKkrhpE93ovjcHl5tU3z3XFu2DRBTg9fbruLMU4QX/9k+s9D86NUkUCeO/ETOtwBRNY3UF2KsiOwyQF5WXDNmfSXcerguKpI0/NPT8+OY42aUPH4g6Mnx2PNUMcOjqkPdncb8t84716nu5xaht4I+akyZyC7DhMUlBeNmjxL+g6XhUUjabVNMpt5bkcebOB0G+xMUqeOU8LimlgkCZzCrLjMElBC1K9l/fititum+tkJxC0NPwz9lvAChvYjp86Hrvszu9eURbY8rSgmAamO8isrq1OB+VGug55tc2Ynv3Ns3Md7hQU2lracH7b+aGaMkXpZR33nEPdvdoXtUe6Zt91fXhg/wN1efJSS8nYtvSsMUgTZSyoKsBvN51XumLq3BSWlpbihc++EHhvrwb/OnHzvrrDBO7/2P2Rr9VY4JBkwYPtmO6geY5vLMlC3Ib1cfOzXptn7lh7R+T7NXuvZioyBl4ZwLSarntsWk1buRMxDZxJU0VBjrKyXdyeD0n0+vCaqe87ti+VHiJJ9NjgwiERUJijrGwXN+CktdPNlvafXlxfOGSQpoqCHGVlu7gBJ6kUQlbXTYLNHyBZYLqDKrixJBNhTuYOklabTtPtP3WCKl6KLlSQFpH7AGwC8D6A25RSk6mOirK3cWd9ThpwfmNJGlwPOM2y9QMkC4FBWkQuB3ClUmq9iNwDYDWAo6mPjLJVXRx87uuVFEf76kqA5qJh4lwOOBRdmJn0RgDLRWQfgD8A+Ga6QyJjuLGEyDphFg5XADihlLoZlVn0TbVPisg2ETkgIgdOnDiRxhjJC2uaiZwQJkifBHBk9uujADprn1RKPaaU6lZKda9YsSLp8ZEXNssnckaYIH0QwEdnv74CzEebx5pmImcEBmml1IsA3hGRXwA4opR6Of1hkS/WNBM5I1QJnlLqC2kPhCJgTXMkSR5zZNO9yA3czJJHrGn2N7JnrpRwaMVq9J+/CGdUpc1ltaER0HwLTh2vU0/Suhe5g9vC84jN8vUaFlUHFs/MBeiqakOjpPk1TyJqFmfSeWWiprlmhpr5Zpew925YVH27rdXzcml0UHO9WxulgzNpCietsr8w9d5R7t2weLpyesbztml0UHO9Wxulg0Gawkmj7C9s8I1y74bF0773xrDk3Lm6x9LqoJZEt7aho0PoeboHXU90oefpntjnDFL+MUhTOGmU/YUNvlHu3XACee+p0+h/732sKrWn3oIzbrvPuKe2uMaVDzTmpCmcNMr+wgbfKPf2aBTVu3EnejPKncdpnhT31BaXuFRJw5k0hdMwQwUQv+xPF+AbH496766twL2/AvrHKn/mpOrFtoVHm2eqLlXSMEhTOGmU/YUNvo6UHCa58Bg3wNqeerHtAy1NTHfQvKAyt6TL/nQ9rIFKpUfjOAoWlBslcWoLkEwqwETqJcpuzSQO5M0LzqSpwlRnvcbUBOBMh7/G2S6ARM4ZTCIVkPVMNerM3aVzDzmTpoqgSousNrH4jaNAM2ndbLf/xn4MbxmOde0kAmzWM9WoM3eXjiFjkKYKbaXFm/V9QqrfA+kEzYw6/JluhJRmOiGJAJtU6iWsZj5YXDmGjOkOVzXu9Csv936dtGbbuzpsxUcMNiyKpZlOSCIVELfmOyru1tRjkHaRV/558n2gpVT/ulIZUN7bqjH+ZjrHdqVR6tfAhvKtNINSUgG29/JeDG8ZxsjdIxjeMpzqrNWlHHNUDNIu8sr7zkwCi89fWObWfon+OuNvAs/8LfDQZckF6wzK7Wwo30o7KGUZYJOQ9cw9T5iTdpEuvzvxHvD3ry98vLF39YK/925zeWpdyV/K5XY2lG+5tPAVlis55qgYpF3U9DZrj79TFbUCo5pyyWpBskbWi2I6DEoUBtMdLmpmm/XGnZVFRD9RKjAMHqbLX60pTziTdpFup59uBlud9eoWEauiVGCkUWoX4VACzmIpLxikXRUl7+s16/UyeaoSKMNcN+muegbTJy4xXV/uIqY7KFjY2W11ATFMpUfSpXYG0yeusKG+3EUM0i4Jc1SVF93s1itHHTYwJl1ql9FORZfZUF/uIqY78i5sHjZOOmDjzoVleKWyPgUSNjAmWWqXxqEEVMeG+nIXcSadZ1E618VJB+hmvbqNLiYCYwY7FV3HrdtmcCadZ1E6xsVNB+hmvV4z7I0762f41b4gE++l10UvasUKRWZLfblrGKTzLErg1aUDpCV8RUYjv6b9tcF74t2asaVYdeHAwQAmcZekGaKUSuxi3d3d6sCBA4ldjwI8cpUmD3vJfAP9qsacdK1SOdn+GLpxBY2RyFEiclAp1e31HHPSeRYlD1vNK8epyPBTWzkSFKAB/zRLs1UoRAXEdEeeRc3Ddm0Fntnm/VzcnX5BTZgata/2rkwBuCmFqAaDdN5FzcOmUaoWdkdiVakMfKjHOxi3eZT2FfD4LKKwmO5wTVKlaqHTGwKUL6z8r7Z877Vh72Bcu8hYq3rIAFMg5BjOpF0TJkXSmIb4UE8lqNZ+f/h7wbNnv8VBXdpFS+Y/DJgCIYewuoPqhcovC4CA/26CKkZ0FSDlC4HpiYb7a+7HChEqCFZ3UHih8st+ATpkHw5d2uWTDy3c3ai7X3Wxk9UgVGBMd1C9OFUeUWa2QWmX2gCvm3VLC/CjL9enXpgKoYIJHaRF5F4AvUqpP0txPM4ZPDSKh/cewVtjE7i4o4ztt6zF5ms7zQ1IV/0RpHVR9MXHsJUpXg2egMohBAe+gwUzbVaDUIGESneIyAcBfD7dobhn8NAo7n/mVYyOTUABGB2bwP3PvIrBQ6PmBuWVhghj0XnpBUW/jThBqRCinAubkx4AcH+aA3HRw3uPYGKq/kiqiakZPLz3SLQLJZmT9ep41/0389/rTLzX/D3Djivo+K5abFFKBRGY7hCROwEcBvBrzfPbAGwDgDVr1iQ6uKJ7a8x7gU73uKc4faJ1vaj90hDafiEZBEVpDReodXXfEc5AJLJFmJn0pwBsBPAkgOtF5Iu1TyqlHlNKdSululesWJHGGAvr4g7vtILucU/N9omO0ou6lsm+zX4BOuiEl2bfL5FhgUFaKXWnUuomAJ8BcFAp9a30h+WG7besRblUn2ctl1qx/Za14S/SbJ/oZoN70sdeRaE9ZGC2qqR/rPKn11h4BiLlFEvwDKpWccSq7mi2F0ecQwBM9W3WHeMVZhbPMxApp0IHaaXU7wCw/C5hm6/tjFdyFxS4dHlYbamdquSdbczXxjl9hWcgUk5xW3gR6AKx1xbv6nZtwH/7t9+2bpsX4Jr5/8KWsZOz/LaFM91RBLr0g18etroz8Lmve88wdRtC4lSTpC3M2Gz9cCHS4Ey6yPo74L3ZQyqLbFFfB0Q7sitrNo+NyAcbLLlKl29tfDzs6wC7F+BsHhtRkxikiyxsTXOU2ucoAT1rNo+NqEkM0kUWtqY5Su2zyc0sQWweG1GTmJOm6PJY3UFkMb+cNIM0LGwXSkROYQmej2q70Go3umq7UAAM1FniDJjIk/NB2q9dKIN0CryCMWBv7TWRYc4H6UTahVI4us0mbWX9phsGaXKc89UdibQLpXB0OyAn3vV+PeubiRikE2kXmjemTteOGnRZ30zEdEci7ULzxGTvDV0nuvKFwPREcy1IiQrOmRI8ltnN0vW3ACqbWNKsqgjqysfqDnKU8yV4LLOr4ZdyaJhVJ/7B5teJjsdYEXlyIkizzK6Gttn/rNmqisGZdel8sHm1VbW5/SmRYU4sHNpeZjd4aBTrdj+Py3YMYd3u5zF4aDS9m3n1t2g0fsz3gy1xPH+QSMuJIG1zmV01FTM6NgGF+RlraoG6rpmSRvvqbD/Y2GKUSMuJIG1zmZ1uxvqVPYfTm1l3ba00wb/9n7Vd4zL9YGOLUSItJ4L05ms7sev2q9HZUYYA6OwoY9ftVyeSj46bqtDNTGeUSn9m7dOiNNMPNrYYJdJypgQvSDOVDI1VI0AlkEX5AFi3+3mMhkghdHaUsX/HhlDXTEqmZYtssEQOY6vSAM0GW12AjRJQve7tRQC8vrs31DWJKF+cr5MO0myJXhKLa7U7Hv1m1KYXOY1vBuJMmxzFII3mg+3FHWXPwNpMQD09Oa19zvQip/HNQKyjJoc5sXCoM3hoFNd8bRi6hE9QsN1+y1qUWqTusVKLRAqo1QD43umpusdvbXkBLyy6B0cX34mD530Jm1v3h75m0jKtmfbCOmpymLMz6cFDo9j+/cOYOucdokPPXiXg+wBeAfDWlhewu/Q4lsokAGDpxPG5mePgzLrM0w7GNwOxjpoc5uxM+uG9R7QBulUkVIXGw3uPYGqm/hpTMyrSDNMr0N3XtmcuQM9feAKnf7wz240vszKtmfZqo8o6anKYs0HabxZ4TqlQs9MkZphege5iecfztUsm3s487TB4aBSnzi7Ml6eSJ6/mnsffBKDmc88f6mEdNTnL2SDtNwv0e65280qLeOc2vP6+btOL16aRt9RFntd969wHvB9PKe1QzZePTdTny5cvLSW2GaiOLvf82rB20w1R0Tmbk95+y1rPnHSpVb/w11jlMONRY944wxw8NIr+H/x3XaDzqo6ozTMf/9P7sPrVBxf0XX5c3QU0ZEGA9MrzvPLlALB0UVs6eXC/3LNX9zwiBzgXpGvrfdvLJUxOz+D01DkAlRnig5uu1AYgXdBqFcE5pRYs5PltVJmYmsGXnvolvrLnMD77sUsaNr9sAC5dvqAueProhyEv/b6uGiXN8rzMFwx1bVSZeyaHORWkGys6xiamUGoR/OMd18TKQZ9Tam43YDWtUX1t0H7OGaXwby/9HgDwD5uvnn+iYeY4eGgU/3Hw1brrCYC/vL4zteqOJOvAQ9m40/vkFuaeyWFObQu/5mvDC/KrANBRLuGXD/YE/n3dNvDlS0tYuqgNo2MTEAQHZi+tIvjtrr/wfG7w0Ci+suewZ3qleu80SvKS6E0SGXcWkoO4LXyWV4D2e7zR9lvWLghapVbB+2em5zajNPuR5xWAgflAqXv+vdNTc/eOshMwzDZvI4f0MvdMVCfXQTrrfhJeQevU2enQQd5Pq6ZSRJcH1wnTcyTKNu/N16aXTiGiYIFBWkQEwL8AWAvg/wDcrpTSN5rISDP9JJYvLS3Yfl19PKzGoHXZjqEow8aiVsHkzMJZ8Wc/tvCklMFDo6HamDYKWtjjmY9E+RGmTnodgDal1A0ALgAQnLzNQDP9JB7cdCVKrQ29NloFD266sulxRFlEK5da8Y0tH8FdN6yZmzm3iuCuG9bULxpi/kNIp1UEHWXvD5egMRnf5k1EoYVJd/wBwMDs1x5VumY0E2jSyLF65amri4cd5RJEgLHTU3X32nxt54Kg3MgvzVFdvAPgubAXVJKXedUGETUtMEgrpV4DABH5NIBFAPbWPi8i2wBsA4A1a9akMERvzQaapHOsaS2u+X3YNFZXRL231wdL3Hpr4/2miQoqVAmeiNwK4MsANiml/qh7XZYleEbKwzKUxKkvfpIMqkX/WRClLdbxWSKyEsD3Afy5UuqU32uzrpOuDTQdS0tQChifmCrETC5K4DM9i037A4Wo6OLWSd8NYBWAvZVCD3xHKfWdBMfXtGrqwvjJISkIm0ax4b1zIZIoPWFy0g8BeCiDsTStqCVlYfLnNrx3LkQSpacQrUpdnsnZ8N692q2aPpeRqChyveOwyraZXBo5Yt01bXjvRraPEzmiEEE6jZKyZoTtHd3MdXV5Z1veO7ePE6WjEOmOzdd2YtftV6OzowxBpaog6/Iv3SkmQPwjroLyzqbfOxGlpxAzacD8TC6oEVKcHHFQ3tn0eyei9BRiJm2DoCAcJ0ec6WndRGQVBumE+AXMuDniXFdPjOwBHrkK6O+o/Dmyx/SIiHKFQTohXoEUSOZk7dzmnUf2VI7DGn8TgKr8+cN7GKiJInDq+Ky0t0+b3p5tnUeu0hwsewlw76+yHw+RpXh8FrLZPs0FvAbjx6I9TkQLOJPuaOaQANdVTz6/bMcQ1u1+HoOHRqNdoH11tMeJaAFngrQN26fzpPqbx+jYBBTmf/OIFKg37gRKDQuqpXLlcSIKxZkgndcyttiz2SYl8ptH11Zg06OVHDSk8uemR3kaOFEEzuSkbdk+HYXJNqSJ/ebRtZVBmSgGZ2bSccrYcj2bbVJef/MgKhorgnRWQXDztZ3Yv2MDHrnjGgDAvU/9MvB+ieRmm2Qyj57rDTREBWI8SGcdBKPez9XZbG430BAVjPGcdNYni0S9n+nZrMk8Ouu+icwzPpPOOghGvZ/p3OzitvkfURJbzIkoX4wH6ayDYNT7mcrNevWnPjN1LvTfNbHQSUTJMx6kkwiCUYJS1PtlkZv1Gn+zuXCTC51ElDwrGizFaUzUWEsMVIKuXyC1qRGSbvy6AwQEwOu7e7XXW7f7ec8zDzs7yti/Y0Ps8RJR8qxvsBRngaqZhUebFsR0428VwYzHB2hQGojb34mKxXi6I668ByXdOGeUaioNZHqhk4iSlfsgnfegpBtnNfcdNRfOTShExWJFuiMO07XEcfmNv5m0TPX1tuTciSie3AfpvAelNMZvU86diOKxorqDiMhlftUduc9JExEVGYM0EZHFcp+TTopNG1yIiKoYpGH2BBQiIj9Md4AniRORvRikkf9di0RUXAzSyP+uRSIqLgZpcCs1EdmLC4fI/65FIiou3yAtIksAPA3gEgAjAP5KJblF0SLcSk1ENgpKd9wF4JhS6iMAlgP4RPpDIiKiqqAgvQHAT2a/fh7AxxtfICLbROSAiBw4ceJE0uMjInJaUJD+AIDx2a9PAriw8QVKqceUUt1Kqe4VK1YkPT4iIqcFBel3ALTPft0++z0REWUkKEg/B6Bn9usNAH6a7nCIiKhWUJD+LoBOERkB8C4qQZuIiDLiW4KnlDoL4FMZjSUz7HhHRHnh3GYWdrwjojxxbls4O94RUZ44F6TZ8Y6I8sS5IM2Od0SUJ84FaXa8I6I8cW7hkB3viChPnAvSADveEVF+OJfuICLKEwZpIiKLMUgTEVmMQZqIyGIM0kREFpMkjywUkRMA3kjsgtm4CMXsk13E91XE9wTwfeVNGu/rg0opz1NTEg3SeSQiB5RS3abHkbQivq8ivieA7ytvsn5fTHcQEVmMQZqIyGIM0sBjpgeQkiK+ryK+J4DvK28yfV/O56SJiGzGmTQRkcUYpImILOZkkBaRJSLyIxE5LCL/KiJiekxJkYonROQlEfmBiBSm06GI3Csi/2l6HEkSkftE5Gci8mMRWWR6PEkQkWUi8qyI7BeRb5geTxJEpCQiP5z9OtP44WSQBnAXgGNKqY8AWA7gE4bHk6R1ANqUUjcAuABAj+HxJEJEPgjg86bHkSQRuRzAlUqp9QB+DGC14SEl5XMAXlJKrQNwpYh82PSA4hCRMoCDmI8TmcYPV4P0BgA/mf36eQAfNziWpP0BwMDs15MmB5KwAQD3mx5EwjYCWC4i+wCsB/C64fEk5SyApbMzzCXI+X+HSqkJpVQXgGOzD2UaP1wN0h8AMD779UkAFxocS6KUUq8ppV4WkU8DWARgr+kxxSUidwI4DODXpseSsBUATiilbkZlFn2T4fEk5XsAPgngfwD8r1Lqt4bHk7RM44erQfodAO2zX7ejYP0FRORWAH0ANimlZkyPJwGfQmXW+SSA60Xki4bHk5STAI7Mfn0UQFGOC7ofwD8ppf4EwIUicqPpASUs0/jhapB+DvO52g0AfmpwLIkSkZUAtgPoVUr90fR4kqCUulMpdROAzwA4qJT6lukxJeQggI/Ofn0FKoG6CM4HcGb267MAzjM4ljRkGj9cDdLfBdApIiMA3kXl//SiuBvAKgB7ReQFEflr0wMib0qpFwG8IyK/AHBEKfWy6TEl5NsAviAiLwIoo1j/voCM4wd3HBIRWczVmTQRUS4wSBMRWYxBmojIYgzSREQWY5AmIrIYgzQRkcX+H+n/UuG+eqeAAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "for c in [cluster1, cluster2, cluster3]:\n",
    "    plot(c)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### KMeans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class KMeans:\n",
    "    def __init__(self, k, max_try=5, max_iters=100, tol=1e-8):\n",
    "        self.k = k\n",
    "        self.max_try = max_try\n",
    "        self.max_iters = max_iters\n",
    "        self.tol = tol\n",
    "        \n",
    "    def fit(self, x):\n",
    "        best_score, best_res = None, None\n",
    "        for _ in range(self.max_try):\n",
    "            score, res = self._fit_one_round(x)\n",
    "            if not best_score or score > best:\n",
    "                best_score = score\n",
    "                best_res = res\n",
    "        return res\n",
    "    \n",
    "    def _fit_one_round(self, x):\n",
    "        centers = x[np.random.choice(range(len(x)), size=self.k)]\n",
    "        for i in range(self.max_iters):\n",
    "            # E step\n",
    "            cls_dict = defaultdict(list)\n",
    "            for sample in x:\n",
    "                c = np.argmin([self._dist(sample, center) for center in centers])\n",
    "                cls_dict[c].append(sample)\n",
    "            # M step\n",
    "            new_centers = np.array([np.mean(samples, 0) for samples in cls_dict.values()])\n",
    "            dist = self._dist(centers, new_centers)\n",
    "            if dist < self.tol:\n",
    "                break\n",
    "            centers = new_centers\n",
    "        return -dist, (new_centers, cls_dict)\n",
    "    \n",
    "    def _dist(self, a, b):\n",
    "        return np.mean((a - b) ** 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "kmeans = KMeans(k=3)\n",
    "centers, cls_dict = kmeans.fit(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWkAAAD2CAYAAAAUPHZsAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAd+klEQVR4nO3df4yV1ZkH8O8zwwUu2M5AZQOOIhI32ljpgmPWiLgRUowZrdQltmubxexmSbptOtGECklrp00Tx5JsM/2RNKTb1DTtKloDrbMGWnCDUo3yI4xNK7tdrCtTbDF2oOLIzDBn/7hzh3vvvOf9/b7nvO/5fhrjeGe491ypz314znOeI0opEBGRndpML4CIiPQYpImILMYgTURkMQZpIiKLMUgTEVlsVppPdumll6ply5al+ZRERKV3+PDht5VSi7y+l2qQXrZsGQ4dOpTmUxIRlZ6IvKH7HssdREQWY5AmIrIYgzQRkcUYpImILMYgTURkMQZpIiKLMUgTEVmMQZqIyGIM0kQ+Bk8MYv1T67HisRVY/9R6DJ4YNL0k57j+e5DqiUOiMhk8MYi+X/Xh/QvvAwBOnTuFvl/1AQB6lvcYXFn2Bk8MYuDIAN469xYWz1+M3lW9Rt6zy78HdcykiTQGjgxMB4e69y+8j4EjA4ZWlI96YDx17hQU1HRgNJHBuvp70IhBmkjjrXNvRXq8LGwKjK7+HjRikCbSWDx/caTHy8KmwOjq70EjBmkijd5VvZjbPrfpsbntc9G7qtfQivJhU2B09fegEYM0kUbP8h703dyHJfOXQCBYMn8J+m7uK/2GlU2B0dXfg0ailErtybq7uxXnSRNdZEuXRFRFXXdRichhpVS31/fYgkeUkSK3j/Us77F+ja5guYMoIzZ1SVBxMUgTZcSmLgkqrsAgLSLzRWS3iBwUkW/ksSiiMrCpS4KKK0wm/WkALymlVgO4TkQ+nPGaiErBpi4JKq4wG4fnAcwTEQEwF8BYtksiKof6xhu7JCiJwBY8EakAeBHAJQD2KaU+1/L9zQA2A8DSpUtveOMN7c3kRGQhttuZ59eCFyZIPwzgD0qp74vIfwD4tlLqV14/yz5pomJpbRMEgEpbBfNmzcPZsbMM2jlJ2if9AQD138HzqGXURFQCXm2C45PjODN2BkCxervLKszG4XcBfFZEXgRQBbAv2yURUV7CtANm3dvt+lD/IIGZtFLq9wBWZ78UIsrb4vmLcercqcCfy6q3u8inMvPCwyxEBZY0C/VqE/SSVW83T2UG4+wOooJKIwttbRPsmNOBd8fexYSamP6ZLHu7eSozGIM0UUKmWtj8stAor986TCnP96Mrt/BU5kUM0kQJ5F1THTwxiP6X+zFyfkT7M0mz0Dwn4PWu6p3RAshTmc1YkyZKIM+a6uCJQXz54Jd9AzRQrCyUQ/2DMZMmSiDPmurAkQGMT477/kyYLNS2E4acXe2PmTRRAnlOugsK/G3SNp3F67o86uWZU+dOQUFNl2fYm2wvBmmiBPKcdBcU+CfVJAD4Bl62vBUPgzRRAnFrqnH6m3tX9aLSVgm1Ll3g9SvP8OSfnXgRLVHOvIYazW2fGzq4N3Z3dMzumJ6z0UogGNo01PTY+qfWe7a8dc7pxPsT78daEyXHi2iJLBJUcvDb1PPaZNMF3sXzF8/YJLz18lux+3e7ZwRjpVQqPdeUPpY7iHKmKznUa8lRN/V0dfFbL791xvPt/t1u3H313TPKM2fHzkZaK+WHmTQZtevoMLbvOY4/jIziss4qttx+DTas7DK9rMiitLXpTtnVuzMahclmdTfA6DL2AycPYO/GvU2PDxwZ8D35Z1vbnktYkyZjdh0dxranX8Xo+IXpx6qVdjxyz/WFCtRRa8y6n28NqHVeteUwrn/seu33Xt30aqg19d3cBwCxa+hZKduHhl9NmuUOMmb7nuNNARoARscvYPue44ZWFE/UtjZdR8iS+Us8fz5uz3WbeP/n7fW4X5eKbW17rvV6s9xBxvxhZDTS47aKc+qwcQOwnhV6lRuS9FzX+6a9Htdlol7ZqG2T6tIaLFUUzKTJmMs6q5Eet1WSU4eNWWGrpHMsdJl555zOGZno1ue34usvfd3z5/M8VRmGbR8aWWOQJmO23H4NqpX2pseqlXZsuf0aQyuKJ8mpQ6+sEKgF2L0b9ybKDL3WVWmrYOT8iOdrPnH8Cc+SQZ6nKsOw7UMjawzSZMyGlV145J7r0dVZhQDo6qwWbtMQSDbJLcussHVdnXM6EdQo4FVnznNSXZhTj7Z8aOR1QpPdHUQG6Q6i1DPpPF6rUdxOkjRE6ZIx3d2R5NSoF3Z3EFkqz6wwTHZeP6VoYoZHlC6SnuU92LtxL4Y2DSUuC8WRZ8cLuzvIqLIcZolLdxAli6ATdDN44ylFE7d3F2lDMM+1MpMmY+qHWYZHRqEADI+MYtvTr2LX0WHTS8tV1KwwbqbrdzN4vc584OQBYz3RRdoQzHOtDNJkTFkOs6TNLwh7HeT40gtfwprH1wQGba8NwP41/Xh106vTHw5ZZIhhP1Rs2RAMI8+1stxBxpTlMEuagi629aqFTqiJ6dGlQeWJoKuq0r69O8pFvX6lH9Mbha3yLFOxu4OMWd2/H8MeAbmrs4qDW9caWJF5Qd0eKx5bAYXg/2b9ukO8xpceOHkAb517Cx+c/UG8N/Fe012KSboW1jy+xvPi3CjdK2l3UtiI3R1kpbIcZklTULkhbEarex6vcskTx5+Y/uczY2eglELnnM7EPdGDJwa1N5tHKZ/YNjskbyx3kDH1Lg6XuztaBZUbelf1zsgqdc/jRXfCsdGEmkB1VhXPf+r5kKv25hdEo5RPitT1kQUGaTJqw8oup4NyK68g3Lgh1bO8B0f/dBRP/veTmFSTkKn/TWLS8+dbhQ1sp86dwuCJwUTlBL/XirLBlnadvGhY7iCySNAR7METg9j9u93TE+4UFNrb2tExuyNUeSJKYEs6/lP3Wh2zOyIFf68LeCttFSu7PrLATJqmuX6wJC9BnQp+HRhe5YrxyXHMq8zDC//wQuBrhy2XAMnHf+r+VLDtb7dFfq7WBoc0Gx5sx0yaAPBgSV6SDqxPWp/1ytQ/ec0nI79e3NeKswk5cGQAE2qi6bEJNcGNQ3KL38ESZtPpSTqwPo36rFemfuDkgUzqvkF92WG4vnHITJoA8GBJXpIGnKxOutl82q9Ix8WzwCBNAMpzS4rtkgacrGY75zkzOiqbP0DywBOHBKA8N3fbzoXTc1mw7Vh42vxOHIaqSYvIFwHcBeBdAHcrpcZSXB9ZgAdL8pHnzIcySaO2XVSBmbSILAfwFaXUJhH5AoBnlFInvH6WmTQRUXRJM+l1ABaIyAEAfwTw7ZYn3wxgMwAsXbo04VIpLPY0E7khzMbhIgCnlVK3ArgcwC2N31RK7VBKdSuluhctWpTFGqkFe5qJ3BEmSJ8FUJ/CfgIA0zXDOCyfyB1hgvRhADdOfX01aoGaDGJPM5E7AmvSSqkXReRtEXkFwG+VUi/nsC7ycVln1XNYPnua9VjDp6IKdZhFKfVZpdSNSql/zHpBFIzD8qPJs4Yf95JYIh3O7igg9jQHa8yc20RwoaXVNIu5JFHu8yMKi0G6oDgsX6/19GRrgK5Lu4afdHgSkRcGaQoti7pu2OeM8tpe3S9e0q7huz6tjbLBIE2htGan9bougNiBOuxzRn3tMBlyFjV81695omxwCh6FkkVvdtjnjPraugy5XQQCoKuzmsngqDSmtXHjkVoxk6ZQsujNDvucUV97y+3XGJnol3R4Ejceoyn7ZLw6BmkKJYve7LDPGfW1TXa/JJnWxo3H8Fz6QGO5g0LJojc77HPGee0NK7twcOtavN7fg4Nb1xaiEybNjcc0yiY2l178PtDKhpk0TfProMgiOw37nK70hae18ZhGlml7pupSJw1vZiEAdt3M4uoR7rRubVn/1HrPYL9k/hLs3bg3t+eIKkqN2cT6suQ3T5rlDgLg30Gx6+gwVvfvx1VbB7G6f3+mI1FdGsPaWk4AkMo9g2lkmXlnqvUPqFPnTkFBTWfuuhKLS/ceMkgTAH2nRD1I5hU0XRnDqgtKALB3414MbRrC3o17Y5UW0rhdO+8buqPWmG2+ODdtrEk7qrWk0FGtYGR0fMbPtYtog2YWJYi8xrCaLqlk2cnRu6rXs2wSJctM4zmiiJO5u3LvIYO0g7xO8FXaBZU2wfjkxT2KaqVde7w6q9nVeYxhzeL0ZFRZlhPSuOw27wtzeVpTj0HaQV4lhfELCgvmVTBv9qym7HL7nuOeQVMB+Juv7oUIMPLeeGrZqO4gSppHuP1KKnkF6ayDUhpZZp6Zat6Ze5EwSDtIlwWPvDeOow+vn/F4a9Cc/vmG8kha2Wge7XY23GzDoNQs78y9SBikHRSlpNAYNL1+TaOo2aiuLpz1GFYbbrZhUJrJlRpzVAzSDopaUqgHzau2DiKoqz5sNmqyLpxHSSUMBiUKgy14DtqwsguP3HM9ujqroafC7To6jDaRwOcOm42abLWL8/6JTGEm7agoJYV61qu74aROANx27aJQz5lFXThKWx1vtqGiYCZNgXQ3nbTm1QrATw8Phzroosu449aFXTqpaJLNQ5fKikGaAumyW6+8OmzJIu2peq6cVDQp6tFtSgeDtGPizOGImt2GKVmkXRe2oa2u7FwaD2oT1qQLLkodNm5Hha4bYs6sNs+j5GGDepp1YRva6srOpfGgNmEmXWBR67BxSwK6rLfv49eFLllkPUkvi0sJqFneQ5eohpl0gUU93pykJOCX9Xpl8o0Zfke1gnNjExi/UKtiZ9ET7crFACbxlKQZDNIFFjXo6koCHdVK7DV4Be/WsopXSSSLWRlsq8sWT0mawSBdYFHrsFtuvwZbnjzWNOkOAM6NTWDX0eHUApyuZa+V7sPE9BhR0uMpyfyxJl1gUeuwG1Z24ZK5Mz+Xxy+oxK1qjTXnoBkfdV4fJux3JmrGTLrA4tRhR96bWXoAkp/0003K06lW2nHbtYuwun//jNGopseIEtmEQbrgotZhs2hVC1PeqLQJLpk7a3r29G3XLsJPDw/PaAfM+5IBItsxSDsm7AS41rrwbdcuwnOvnW7KeoHgEaYCeGb4q/v3e2bM7SKeM0I6qpUZWTcza3KBqIChOVF0d3erQ4cOpfZ8lI2gjbkw5YtKmwCC6bY6L12dVRzcutbze35jT1uv7fJ6rWqlnZPrqDRE5LBSqtvre8ykHRRUIglTvmjtEGkVdJBEV3bpaqhN1z9E3hubwJ9baumsU5MrGKRphqT1364Q5Qi/skvrh8hVW70H+AyPjGJ1//6m0gvLIVQ2DNI0gy7LDcOvxNEoSmeK33qGR0ax5cljTeUQE7d/E2UldJ+0iDwgIr/McjHOGtoJfPMjQF9n7e9DO40ux6v/OoyoszI2rOzCwa1r8Xp/Dw5uXasNqEHrGZ9UM2rjHFNKZREqSIvIlQDuz3YpjhraCfz8C8CZNwGo2t9//gWjgdproNJnblqKLp82vXaRzDby6utZMC/a8XW27VEZhC13DADYBuDB1m+IyGYAmwFg6dKl6a3MFfu+Boy3BJPx0drjK+4N/TRpH6XWbS7qujImlcq0tLBhZRe27zk+YwPRT2vvN4+bUxEFBmkRuQ/AMQC/8fq+UmoHgB1ArQUv1dW54MzJaI97iDsnOk7QMjm32S8zrrTLjBa9xtKLydvJiZIIU+64E8A6AI8DuEFEPp/tkhzTcXm0xz3EmRMdd0aGybnNug+CBfMq2L7xo763vPB6LSqqwExaKXUfAIjIMgDfV0p9J+M1uWXdw7UadGPJo1KtPR5SnDnRcWdkmJzbrGvb+8pd1wX2fvN6LSoqtuCZVq877/tarcTRcXktQEeoR/uVIHQljawuAMhSkg8IXq9FRcVj4SXgdYy7WmnH39/Q1TTEqP74I/dc7ztzI+gwiu0bcF7rA+D574hHy8kGfsfCOU+6BHR3ED732mltScOv99ivPm37vGfd+gCkejs5UV6YSZeYrl1OALze3zOdcfpl1K2nB1f379fO3Ahz0jBrtq+PyAszaUfp6q31x+sn/kTz673q07ZvwNm+PqKoGKRLLGy7XFAwj/uzJti+PqKoGKRLTFerbq3DRul9NtknHYbt6yOKii14JRemXS5Ka5vJPukwbF8fUVTcOARqw4wS9CkTESXBm1n81KfQ1U/81afQAQzURGQcg3RKU+goHL+DMLYfkiEygUE6hSl0FI7fJDoAnFJH5IFBuuPyqYH7Ho9TqoIm0cUZ+ERUdmzBW/dwbepco4hT6ArH0HVdfgdNeAiFyBuD9Ip7gbu+BXRcAUBqf7/rW+WtRxu8rsvvoAkPoRB5c6vcoWu1q//lAoMbpbp50H5T6ngIhVznTiZt4YWvRmg3St+cUf7YdXQYq/v346qtg1jdvz/xpLugE5BzKxf/79hZrXBKHRFcyqTZalej2ygF0Pjh9crv/4xtr1yZereF1wlIr3nY5ycmY78GUZm4k0nb3GqX50ae10Zpq/FRXHFke253AvL+QSI9d4J0Che+ZiLvMkzrRqnGX6m3PR/PotuCnR1Eeu4EaVtb7XRlmGcfyi67XnEv8MCvgb6RqWA905/kUs/Hs+i2YGcHkZ47QTqrVrukpQpduWX0nXyya82H15urtuQ28pPjRYn03Nk4BIJb7aJOw0tjOJPvRl6DrDY5NbeV37jiXjxyRT6zNDhelEiPo0rrWgMuUMsw/bLtb35Ec6T8ilo5Ie7rakmtREFEpcJRpWHEadFLo2OkKZN9E5B2QF3w/lmTm5yGZ25zQh65yp2adJA4ATetjpEV9wJ/vR6A6AO0yU1OwweB6n3UwyOjULjYs530cA1RETBIP/Mg8NWFADRlH7+Am1bHyNBO4NAP9GuQduCj95k7dOP3p4wcsI+aXOZ2ueOZB4FD/67/flDAXXEv8H8vAYd/WMuA4wbTfV+DNkADtec+9hNg6U0Xfz7PsoPhg0DsoyaXuR2kD/9Q/72OK8J1dxz7ycUSRWMwjRI4wwS7eu/0xGj+V33lOHPbq/Z8WWcVwx4BmX3U5AK3yx26+i9Q684ICnxplQHCBrvRd/IvOwztBMbOzXw8gxq5rvZ827WL2EdNznI7SEt7tMdbRSkD+B16CTNPI846kqpvGI6+0/x4dWEmM7d1tefnXjvtOz2PqMzcLnfccL93TfqG+/W/prEVTdq8s/HWzPiZB5s3BlvLFK0HSqoLgLF3gQtjF5+jUgVmVWcGTK/XS4vXnxQAYPb8TMorfrVnr+l5RC5wL0i39vte9XfA71+4uPF3w/3Anf+m/7WNB0+8AnRrGUDXuTE+Cjz9L7W/qguBOx5tPgDj1ZcMALv+FZgcv/hzbZXsWvNy3jBk7ZloJreC9NDO5iB35k3gL28Bn/heuMxQl1lKO6Amvbstnn0Ivp0bQC073v252tf1X+t1hH1oJyAtk+ta/zlNOV/SG3RzC5GL3DoW/uhV3uWC6kLgodeDf31fJ/T91FdcLFUAwOifa197vZ6O7jj5dFatmfHh9yGRRJyj8gnxZCG5iMfC63QBM2wg1Q5DkouPNz5XlAAN6Dccg2Z71MsuUVrywhzz1gxfyrLdj7VnombFD9J5zpRY97BHwBQEljPC8ioj6EosOmGm5UWZ3ufSJb1EFip2C17UmRLVhdEeb+U1kzqtAN0+23sDMMwY0xm/JmBjz/AxbyIKLzBIS81jIvKSiPxMROzJvqMGmzserQXDRu2za4+H1XiryQO/1t5solWpAt3/3PzBUF0I3P1d741CnyuutP3cQRt7Nt/3SERNwgTc1QBmKaVuEpH/ArAewH9muqqwogabLGqsniWQBm0VYM4HahuJja+na/NrpJ3pIcA9O2pfem3sBbXk5dy1QUTxhQnSfwQwMPX1WOs3RWQzgM0AsHTp0vRWFkacYJN2jdXrIAowMyjHoc1sVfNzRv3Q8fpgSeOYt+GZ00RlFLoFT0Q+AaAXwDqlvIde5N6CZ6BFLFdp3Pyik3ZALfvvBVGG/FrwQgVpEfk4gAcB3KWU+ovu54z0STcGmzSzWBuEDXw2ZLBZfqAQlZxfkA6zcbgYwBYAPX4B2pj6Rt49O2pjPEffgYnbQzIR5oZzw7emTONmJFEmwtSkNwFYAmCP1I4g/0Ap9YNMVxVHnDsKiyCohm7L++ZmJFEmAjNppdSjSqmrlVK3TP1lX4AG3M3kbHnfaV0lRkRNin2YpVFal8KmxW9+dJrPZ8v7DlOaIaLI7DmYklRWbWVRNA1CajgunvSaK79j3Da87zoeISdKXXkyadOZXNMGHuA5PzruseugujMzWKLSKk8mDZjN5MIMQopbJw6qOzODJSqt8mTSpoUJwHHrxLbUnYkodwzSaQkKmEnqxEXunEh7A5XIMQzSafG88Xtqgl3SOnFR6862HLQhKrBy1aSDZHl8OutbTIpYd7bloA1RgbkTpKPcRhJXEQNplmw5aENUYO6UO3gbSTxJasrc8CRKzJ0gXeSsztTmW9KacpE3PIks4U6QLmpWZ3LzLemfPoq64UlkEXdq0jYdn47C5OZbGn/6YJ2eKBF3MumkWZ2pkoPJMk1R//RBVCJ2BOm8AmDjBQEA8PTmcK9nsuRgMlCypkxknPkgnXcAjPN6JjtDTAZK1pSJjDNfk8675hrn9UyWHLI+JBPm9RmUiYwxn0nnHQDjvJ7JkoMNl8wSkTHmg3RaATBsXTvO6+VRcvBaf9xSEIcaEZWG+SCdRgCMEszivF7WtVnd+p99KHotnEONiEpFlFLBPxVSd3e3OnToUPRfmPSP9N/8iOam6itq3Rxpv17adOvXEqBvJNpz6f5dEJFxInJYKdXt9T3zG4dA8s2pqHVm2zbDotbf/UozRT7+TkQzmC93pKHohy5066wujF6aKfq/CyJqUo4gXfRDF7r13/Fo9Fp40f9dEFETO8odSZnuJU4qaP1R3kfR/10QURM7Ng6JiBzmt3FYjnIHEVFJMUgTEVmMQZqIyGIM0o14nJqILFOO7o405HGbOBFRRMyk63ibOBFZiEG6jsepichCDNJ1PE5NRBZikK7jcWoispBvkBaRuSLyjIgcE5EfiYjktbDc8T4/IrJQUHfHZwCcVErdKSLPAPgYgL3ZL8sQ20aYEpHzgsodawH8Yurr/QBuy3Y5RETUKChIfwjAmamvzwJY2PoDIrJZRA6JyKHTp0+nvT4iIqcFBem3AXRMfd0x9c9NlFI7lFLdSqnuRYsWpb0+IiKnBQXpfQDWT329FsBz2S6HiIgaBQXpHwPoEpEhAO+gFrSJiCgnvt0dSqnzAO7MaS35su3GcCIiD24OWOIwJSIqCDdPHHKYEhEVhJtBmsOUiKgg3AzSHKZERAXhZpDmMCUiKgg3gzSHKRFRQbjZ3QFwmBIRFYKbmTQRUUEwSBMRWYxBmojIYgzSREQWY5AmIrIYgzQRkcUYpImILCZKqfSeTOQ0gDdSe8J8XAqPG2dKgO+rWMr4vsr4noBs3teVSinPq61SDdJFJCKHlFLdpteRNr6vYinj+yrjewLyf18sdxARWYxBmojIYgzSwA7TC8gI31exlPF9lfE9ATm/L+dr0kRENmMmTURkMQZpIiKLORukRWSuiDwjIsdE5EciIqbXlAapeUxEXhKRn4lIqWaGi8gDIvJL0+tIk4h8UUSeF5FnRWS26fUkJSLzRWS3iBwUkW+YXk8aRKQiIj+f+jrX2OFskAbwGQAnlVIfBbAAwMcMryctqwHMUkrdBOCDANYbXk9qRORKAPebXkeaRGQ5gOuUUmsAPAugDBdtfhrAS0qp1QCuE5EPm15QEiJSBXAYF2NErrHD5SC9FsAvpr7eD+A2g2tJ0x8BDEx9PWZyIRkYALDN9CJStg7AAhE5AGANgNcNrycN5wHMm8ow56Lg/z9USo0qpVYAODn1UK6xw+Ug/SEAZ6a+PgtgocG1pEYp9T9KqZdF5BMAZgPYY3pNaRCR+wAcA/Ab02tJ2SIAp5VSt6KWRd9ieD1p+AmAOwD8FsBrSqn/NbyetOUaO1wO0m8D6Jj6ugMlmjEgIh8H0AvgLqXUBdPrScmdqGWdjwO4QUQ+b3g9aTkL4PjU1ycAdBlcS1q2AfieUupaAAtF5GbTC0pZrrHD5SC9DxfrtWsBPGdwLakRkcUAtgDoUUr9xfR60qKUuk8pdQuATwE4rJT6juk1peQwgBunvr4atUBddB8A8P7U1+cBXGJwLVnINXa4HKR/DKBLRIYAvIPav/gy2ARgCYA9IvKCiPyT6QWRnlLqRQBvi8grAI4rpV42vaYUfBfAZ0XkRQBVlOe/rbpcYwdPHBIRWczlTJqIyHoM0kREFmOQJiKyGIM0EZHFGKSJiCzGIE1EZLH/B1U5gb0+P0c7AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "for samples in cls_dict.values():\n",
    "    plot(samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gaussian Mixture Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GM:\n",
    "    def __init__(self, k, max_iters=100):\n",
    "        self.k = k\n",
    "        self.max_iters = max_iters\n",
    "        \n",
    "    def fit(self, x):\n",
    "        models = self._init_models()\n",
    "        for i in range(self.max_iters):\n",
    "            # E step\n",
    "            probs_matrix = []\n",
    "            for sample in x:\n",
    "                probs = [self._get_probs(sample, model) for model in models]\n",
    "                probs_matrix.append(probs)\n",
    "            probs_matrix = np.array(probs_matrix)\n",
    "            # M step\n",
    "            new_models = []\n",
    "            for i in range(len(models)):\n",
    "                new_models.append(self._update_model(probs_matrix[:, i:i + 1], x))\n",
    "            \n",
    "            models = new_models\n",
    "        centers = self._get_centers(probs_matrix, models)\n",
    "        return centers, probs_matrix\n",
    "    \n",
    "    def _init_models(self):\n",
    "        models = []\n",
    "        for i in range(self.k):\n",
    "            mean = np.array([1., 1.])\n",
    "            cov = np.array([[1., 0.], [0., 1.]])\n",
    "            models.append((mean, cov))\n",
    "        return models\n",
    "    \n",
    "    def _get_probs(self, sample, model):\n",
    "        mean, cov = model\n",
    "        return np.exp(-0.5 * (sample - mean).T.dot(np.linalg.inv(cov + 1e-8 * np.eye(len(cov)))).dot(sample - mean)) / ((2 * np.pi) ** 2 * np.linalg.det(cov)) ** 0.5\n",
    "\n",
    "    def _update_model(self, p, x):\n",
    "        mean = np.mean(p * x, 0)\n",
    "        cov = np.cov(p * x, rowvar=0)\n",
    "        return (mean, cov)\n",
    "    \n",
    "    def _get_centers(self, probs, models):\n",
    "        return np.array([models[np.argmax(p)][0] for p in probs])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/abbo/anaconda3/envs/lab/lib/python3.6/site-packages/ipykernel_launcher.py:34: RuntimeWarning: invalid value encountered in double_scalars\n"
     ]
    }
   ],
   "source": [
    "gm = GM(k=3)\n",
    "centers, _ = gm.fit(data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
